package com.tencent.angel.sona.ml.classification

import com.tencent.angel.client.AngelPSClient
import com.tencent.angel.ml.core.PSOptimizerProvider
import com.tencent.angel.ml.math2.utils.{LabeledData, RowType}
import com.tencent.angel.mlcore.conf.{MLCoreConf, SharedConf}
import com.tencent.angel.mlcore.variable.VarState
import com.tencent.angel.psagent.{PSAgent, PSAgentContext}
import com.tencent.angel.sona.core.{AngelGraphModel, DriverContext, ExecutorContext, SparkMasterContext}
import com.tencent.angel.sona.ml.PredictorParams
import com.tencent.angel.sona.ml.common._
import com.tencent.angel.sona.ml.evaluation.evaluating.{BinaryClassificationSummaryImpl, MultiClassificationSummaryImpl}
import com.tencent.angel.sona.ml.evaluation.training.ClassificationTrainingStat
import com.tencent.angel.sona.ml.evaluation.{ClassificationSummary, TrainingStat}
import com.tencent.angel.sona.ml.math2.utils.CusLabeledData
import com.tencent.angel.sona.ml.param.shared.HasProbabilityCol
import com.tencent.angel.sona.ml.param._
import com.tencent.angel.sona.ml.util._
import com.tencent.angel.sona.util.ConfUtils
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.linalg
import org.apache.spark.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Example
import org.apache.spark.sql.functions.udf

import scala.collection.JavaConverters._
import scala.util.Random

/**
 * Spark ML `Classifier` backed by the Angel parameter server (PS).
 *
 * `train()` converts the input `Dataset` to an RDD of `Example`s, infers/validates
 * feature count, model type and model size from the data, builds an
 * [[AngelGraphModel]] on the PS, and runs mini-batch training for `maxIter` epochs.
 * PS handles and broadcast contexts are held in mutable fields populated when
 * `train()` runs; call `releaseAngelModel()` to free the PS-side model afterwards.
 */
class CusAngelClassifier(override val uid: String)
  extends Classifier[linalg.Vector, CusAngelClassifier, CusAngelClassifierModel]
    with AngelGraphParams with AngelOptParams with HasNumClasses with ClassifierParams
    with DefaultParamsWritable with Logging{
  // Spark session of the training dataset; assigned at the start of train().
  private var sparkSession: SparkSession = _

  // Driver-side Angel context shared across components on this JVM.
  private var driverCtx = DriverContext.get()
  // PS client/agent handles; implicit so helpers needing them resolve these fields.
  private implicit var psClient: AngelPSClient = _
  private implicit var psAgent: PSAgent = _
  private var sparkMasterCtx: SparkMasterContext = _
  override val sharedConf: SharedConf = driverCtx.sharedConf
  // Broadcast copies of the executor context and configuration used inside tasks.
  implicit var bcExeCtx: Broadcast[ExecutorContext] = _
  implicit var bcConf: Broadcast[SharedConf] = _
  // PS-resident model graph; created during train(), cleared by releaseAngelModel().
  private var angelModel: AngelGraphModel = _

  def this() = {
    this(Identifiable.randomUID("CusAngelClassification_"))
  }

  /** Sets the number of target classes. */
  def setNumClass(value: Int): this.type = setInternal(numClass, value)

  setDefault(numClass -> MLCoreConf.DEFAULT_ML_NUM_CLASS)

  /**
   * Copies the Param values configured on this estimator into the shared Angel
   * configuration so PS-side components observe the same settings.
   */
  override def updateFromProgramSetting(): this.type = {
    sharedConf.set(MLCoreConf.ML_IS_DATA_SPARSE, getIsSparse.toString)
    sharedConf.set(MLCoreConf.ML_MODEL_TYPE, getModelType)
    sharedConf.set(MLCoreConf.ML_FEATURE_INDEX_RANGE, getNumFeature.toString)
    sharedConf.set(MLCoreConf.ML_NUM_CLASS, getNumClass.toString)
    sharedConf.set(MLCoreConf.ML_MODEL_SIZE, getModelSize.toString)
    sharedConf.set(MLCoreConf.ML_FIELD_NUM, getNumField.toString)

    sharedConf.set(MLCoreConf.ML_EPOCH_NUM, getMaxIter.toString)
    sharedConf.set(MLCoreConf.ML_LEARN_RATE, getLearningRate.toString)
    sharedConf.set(MLCoreConf.ML_OPTIMIZER_JSON_PROVIDER, classOf[PSOptimizerProvider].getName)
    sharedConf.set(MLCoreConf.ML_NUM_UPDATE_PER_EPOCH, getNumBatch.toString)
    sharedConf.set(MLCoreConf.ML_OPT_DECAY_CLASS_NAME, getDecayClass.toString)
    sharedConf.set(MLCoreConf.ML_OPT_DECAY_ALPHA, getDecayAlpha.toString)
    sharedConf.set(MLCoreConf.ML_OPT_DECAY_BETA, getDecayBeta.toString)
    sharedConf.set(MLCoreConf.ML_OPT_DECAY_INTERVALS, getDecayIntervals.toString)
    sharedConf.set(MLCoreConf.ML_OPT_DECAY_ON_BATCH, getDecayOnBatch.toString)

    this
  }

  /**
   * Trains the model on the Angel PS.
   *
   * Steps (statement order matters — PS handles and broadcasts must exist before
   * the stages that use them):
   *   1. acquire PS client/agent, convert the Dataset to RDD[Example];
   *   2. create Instrumentation for logging;
   *   3. infer/validate numFeature, modelType, modelSize from a sample row and
   *      compute per-partition feature stats;
   *   4. finalize and broadcast the shared configuration;
   *   5. build the model graph, create PS matrices, load (incremental) or init weights;
   *   6. run the epoch loop: executors train mini-batches, the driver applies the
   *      optimizer update once per batch group.
   *
   * @param dataset training data with label, optional weight, and features columns
   * @return a fitted [[CusAngelClassifierModel]] carrying the training summary
   */
  override def train(dataset: Dataset[_]): CusAngelClassifierModel = {
      sparkSession = dataset.sparkSession
      psClient = driverCtx.getAngelClient
      psAgent = driverCtx.getPSAgent
      sparkMasterCtx = driverCtx.sparkMasterContext
      sharedConf.set(ConfUtils.ALGO_TYPE, "class")
      // 1. trans Dataset to RDD; missing/empty weight column defaults every weight to 1.0
      val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
      val instances: RDD[Example] =
        dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
          case Row(label: Double, weight: Double, features: linalg.Vector) => Example(label, weight, features)
        }

      // One Angel task per RDD partition.
      val numTask = instances.getNumPartitions
      psClient.setTaskNum(numTask)

      bcExeCtx = instances.context.broadcast(ExecutorContext(sharedConf, numTask))
      DriverContext.get().registerBroadcastVariables(bcExeCtx)

      // persist RDD if StorageLevel is NONE (the data is iterated multiple times below)
      val handlePersistence = dataset.storageLevel == StorageLevel.NONE
      if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

      // 2. create Instrumentation for log info
      val instr = Instrumentation.create(this, instances)
      instr.logParams(this, maxIter)

      // 3. check data configs using the first sample's feature vector
      val example = instances.take(1).head.features
      // 3.1 NumFeature check: reconcile the configured feature count with the data
      if (example.size != getNumFeature && getNumFeature != -1) {
        // has set
        setNumFeatures(Math.max(example.size, getNumFeature))
        log.info("number of feature form data and algorithm setting does not match")
      } else if (example.size != getNumFeature && getNumFeature == -1) {
        // not set
        setDefault(numFeature, example.size)
        log.info("get number of feature form data")
      } else {
        log.info("number of feature form data and algorithm setting match")
      }
      instr.logNamedValue("NumFeatures", getNumFeature)
      // 3.2 better modelType default value for sona
      //     (dense rows for small feature spaces, long-keyed sparse beyond Int range)
      if (getModelSize == -1) {
        if (example.size < 1e6) {
          setDefault(modelType, RowType.T_DOUBLE_DENSE.toString)
        } else if (example.size < Int.MaxValue) {
          setDefault(modelType, RowType.T_DOUBLE_SPARSE.toString)
        } else {
          setDefault(modelType, RowType.T_DOUBLE_SPARSE_LONGKEY.toString)
        }
      } else {
        example match {
          case _: DenseVector =>
            setDefault(modelType, RowType.T_DOUBLE_DENSE.toString)
          case iv: IntSparseVector if iv.size <= (2.0 * getModelSize) =>
            setDefault(modelType, RowType.T_DOUBLE_DENSE.toString)
          case iv: IntSparseVector if iv.size > (2.0 * getModelSize) =>
            setDefault(modelType, RowType.T_DOUBLE_SPARSE.toString)
          case _: LongSparseVector =>
            setDefault(modelType, RowType.T_DOUBLE_SPARSE_LONGKEY.toString)
        }
      }

      // 3.3 ModelSize check && partitionStat
      //     For sparse data without a configured model size, count validated features
      //     through a temporary PS matrix. NOTE(review): vectors that are neither
      //     Dense nor Sparse would hit a MatchError here — presumably cannot occur.
      val featureStats = new FeatureStats(uid, getModelType, bcExeCtx)
      val partitionStat = if (getModelSize == -1) {
        // not set
        example match {
          case v: DenseVector =>
            setModelSize(v.size)
            instances.mapPartitions(featureStats.partitionStats, preservesPartitioning = true)
              .reduce(featureStats.mergeMap).asScala.toMap
          case _: SparseVector =>
            featureStats.createPSMat(psClient, getNumFeature)
            val partitionStat_ = instances.mapPartitions(featureStats.partitionStatsWithPS, preservesPartitioning = true)
              .reduce(featureStats.mergeMap).asScala.toMap

            val numValidateFeatures = featureStats.getNumValidateFeatures(psAgent)
            setModelSize(numValidateFeatures)
            partitionStat_
        }
      } else {
        // has set
        instances.mapPartitions(featureStats.partitionStats, preservesPartitioning = true)
          .reduce(featureStats.mergeMap).asScala.toMap
      }

      // 3.4 input data format check and better modelType default value after model known
      example match {
        case _: DenseVector =>
          setIsSparse(false)
          setDefault(modelType, RowType.T_DOUBLE_DENSE.toString)
        case iv: IntSparseVector if iv.size <= (2.0 * getModelSize) =>
          setIsSparse(true)
          setDefault(modelType, RowType.T_DOUBLE_DENSE.toString)
        case iv: IntSparseVector if iv.size > (2.0 * getModelSize) =>
          setIsSparse(true)
          setDefault(modelType, RowType.T_DOUBLE_SPARSE.toString)
        case _: LongSparseVector =>
          setIsSparse(true)
          setDefault(modelType, RowType.T_DOUBLE_SPARSE_LONGKEY.toString)
      }

      // update sharedConf and ship the final configuration to executors
      finalizeConf(psClient)
      bcConf = instances.context.broadcast(sharedConf)
      DriverContext.get().registerBroadcastVariables(bcConf)

      /** *******************************************************************************************/
      implicit val dim: Long = getNumFeature
      // Reorganize the data into batch-grouped RDDs for mini-batch training.
      val manifoldBuilder = new CusManifoldBuilder(instances, getNumBatch, partitionStat)
      val manifoldRDD = manifoldBuilder.manifoldRDD()

      val globalRunStat: ClassificationTrainingStat = new ClassificationTrainingStat(getNumClass)
      // Build the result model up-front: its angelModel hosts the PS-side graph used below.
      val sparkModel: CusAngelClassifierModel = copyValues(
        new CusAngelClassifierModel(this.uid, getModelName),
        this.extractParamMap())

      sparkModel.setBCValue(bcExeCtx)

      angelModel = sparkModel.angelModel

      angelModel.buildNetwork()
      logInfo(s"graph:${angelModel.graph.toJsonStr}")
      val startCreate = System.currentTimeMillis()
      angelModel.createMatrices(sparkMasterCtx)
      PSAgentContext.get().getPsAgent.refreshMatrixInfo()
      val finishedCreate = System.currentTimeMillis()
      globalRunStat.setCreateTime(finishedCreate - startCreate)

      if (getIncTrain) {
        // Incremental training: warm-start from a previously saved model.
        val path = getInitModelPath
        require(path.nonEmpty, "InitModelPath is null or empty")

        val startLoad = System.currentTimeMillis()
        angelModel.loadModel(sparkMasterCtx, MLUtils.getHDFSPath(path), null)
        val finishedLoad = System.currentTimeMillis()
        globalRunStat.setLoadTime(finishedLoad - startLoad)
      } else {
        val startInit = System.currentTimeMillis()
        angelModel.init(SparkMasterContext(null))
        val finishedInit = System.currentTimeMillis()
        globalRunStat.setInitTime(finishedInit - startInit)
      }

      angelModel.setState(VarState.Ready)

      /** training **********************************************************************************/
      // log (zh): "start training (getMaxIter epochs total)"
      logInfo(s"开始训练(共${getMaxIter}轮)")
      (0 until getMaxIter).foreach { epoch =>
        // log (zh): "epoch ${epoch+1} training starts"
        logInfo(s"第${epoch+1}轮训练开始")
        globalRunStat.clearStat().setAvgLoss(0.0).setNumSamples(0)
        manifoldRDD.foreach { case batch: RDD[Array[CusLabeledData]] =>
          // training one batch: gradients are computed on executors, then the
          // driver triggers the PS-side optimizer update.
          val trainer = new CusTrainer(bcExeCtx, epoch, bcConf)
          val runStat_ = batch.filter(miniBatch=>miniBatch.length>0).map(miniBatch => {
            trainer.trainOneBatch(miniBatch)
          })
          val count = runStat_.count()
          if(count>0) {
            val runStat = runStat_.reduce(TrainingStat.mergeInBatch)
            logInfo(s"avg loss:${runStat.getAvgLoss}")
            // those code executor on driver
            val startUpdate = System.currentTimeMillis()
            angelModel.update(epoch, 1)
            val finishedUpdate = System.currentTimeMillis()
            runStat.setUpdateTime(finishedUpdate - startUpdate)

            globalRunStat.mergeMax(runStat)
            globalRunStat.printString()
          }

        }
        globalRunStat.addHistLoss()
        println(globalRunStat.printString())
      }
      // log (zh): "training finished"
      logInfo("训练完成")
      /** *******************************************************************************************/

      instr.logInfo(globalRunStat.printString())

      sparkModel.setSummary(Some(globalRunStat))
      instr.logSuccess()
//      driverCtx.stopAngelAndPSAgent()
      sparkModel
  }

  override def copy(extra: ParamMap): CusAngelClassifier = defaultCopy(extra)

  /**
   * Releases the PS-side resources held by the trained model graph, if any,
   * and clears the local reference. Safe to call when no model was trained.
   */
  def releaseAngelModel(): this.type = {
    if (angelModel != null) {
      angelModel.releaseMode(driverCtx.sparkWorkerContext)
    }

    angelModel = null
    this
  }
}


/** Companion providing `DefaultParamsReadable`-based persistence for [[CusAngelClassifier]]. */
object CusAngelClassifier extends DefaultParamsReadable[CusAngelClassifier] with Logging {
  /** Loads a previously saved [[CusAngelClassifier]] estimator from `path`. */
  override def load(path: String): CusAngelClassifier = {
    val estimator = super.load(path)
    estimator
  }
}


/**
 * Model produced by [[CusAngelClassifier]].
 *
 * Scoring is performed distributively in `transform()` via a `Predictor` that
 * reads weights from the Angel PS; the single-row `predict`/`predictRaw` entry
 * points are intentionally unimplemented (`???`) — use `transform()` instead.
 */
class CusAngelClassifierModel(override val uid: String, override val angelModelName: String)
  extends ClassificationModel[linalg.Vector, CusAngelClassifierModel] with AngelSparkModel
    with HasProbabilityCol with PredictorParams with HasNumClasses with MLWritable with Logging {
  // PS client resolved from the driver context; @transient so it is not serialized to executors.
  @transient implicit override val psClient: AngelPSClient = DriverContext.get().getAngelClient
  override lazy val numFeatures: Long = getNumFeature
  override lazy val numClasses: Int = getNumClass
  override val sharedConf: SharedConf = DriverContext.get().sharedConf
//  var driverContext = DriverContext.get()

  /** Sets the output column name for predicted probabilities. */
  def setProbabilityCol(value: String): this.type = setInternal(probabilityCol, value)

  /**
   * Copies the Param values configured on this model into the shared Angel
   * configuration so PS-side components observe the same settings.
   */
  override def updateFromProgramSetting(): this.type = {
    sharedConf.set(MLCoreConf.ML_IS_DATA_SPARSE, getIsSparse.toString)
    sharedConf.set(MLCoreConf.ML_MODEL_TYPE, getModelType)
    sharedConf.set(MLCoreConf.ML_FIELD_NUM, getNumField.toString)

    sharedConf.set(MLCoreConf.ML_FEATURE_INDEX_RANGE, getNumFeature.toString)
    sharedConf.set(MLCoreConf.ML_OPTIMIZER_JSON_PROVIDER, classOf[PSOptimizerProvider].getName)

    this
  }

  /**
   * Returns a model guaranteed to have non-empty probability and prediction
   * column names (generating UUID-suffixed names on a copy if either is empty),
   * along with the resolved column names.
   *
   * @return (model to use for transform, probability column name, prediction column name)
   */
  def findSummaryModel(): (CusAngelClassifierModel, String, String) = {
    val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
      copy(ParamMap.empty)
        .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
    } else if ($(probabilityCol).isEmpty) {
      copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
    } else if ($(predictionCol).isEmpty) {
      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
    } else {
      this
    }
    (model, model.getProbabilityCol, model.getPredictionCol)
  }

  /**
   * Scores `dataset` and wraps the result in a classification summary:
   * multi-class (prediction-based) when numClasses > 2, binary
   * (probability-based) otherwise.
   */
  def evaluate(dataset: Dataset[_]): ClassificationSummary = {
    val taskNum = dataset.rdd.getNumPartitions
    setNumTask(taskNum)

    // Handle possible missing or invalid prediction columns
    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
    if (numClasses > 2) {
      new MultiClassificationSummaryImpl(summaryModel.transform(dataset),
        predictionColName, $(labelCol))
    } else {
      new BinaryClassificationSummaryImpl(summaryModel.transform(dataset),
        probabilityColName, $(labelCol))
    }
  }

  /**
   * Appends probability and prediction columns (both DoubleType — the probability
   * column holds a single value, not a vector) computed by a PS-backed `Predictor`.
   *
   * Side effects: sets default values for empty probability/prediction column
   * Params, and lazily creates/registers the `bcValue`/`bcConf` broadcasts on
   * first use so executors can reach the shared configuration.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)

    val taskNum = dataset.rdd.getNumPartitions
    setNumTask(taskNum)

    val featIdx: Int = dataset.schema.fieldIndex($(featuresCol))

    // Resolve (or generate) output column names; generated names are UUID-suffixed
    // to avoid clashing with existing columns.
    val probabilityColName = if ($(probabilityCol).isEmpty) {
      val value = s"probability_${java.util.UUID.randomUUID.toString}"
      setDefault(probabilityCol, value)
      value
    } else {
      $(probabilityCol)
    }

    val predictionColName = if ($(predictionCol).isEmpty) {
      val value = s"prediction_${java.util.UUID.randomUUID.toString}"
      setDefault(predictionCol, value)
      value
    } else {
      $(predictionCol)
    }

    if (bcValue == null) {
      finalizeConf(psClient)
      bcValue = dataset.rdd.context.broadcast(ExecutorContext(sharedConf, taskNum))
      DriverContext.get().registerBroadcastVariables(bcValue)
    }

    if (bcConf == null) {
      finalizeConf(psClient)
      bcConf = dataset.rdd.context.broadcast(sharedConf)
      DriverContext.get().registerBroadcastVariables(bcConf)
    }

    val predictor = new Predictor(bcValue, featIdx, probabilityColName, predictionColName, bcConf)

    val newSchema: StructType = dataset.schema
      .add(probabilityColName, DoubleType)
      .add(predictionColName, DoubleType)

    // Dataset[_].rdd is typed RDD[Any] here; rows are known to be Row at runtime.
    val rddRow = dataset.rdd.asInstanceOf[RDD[Row]]
    val rddWithPredicted = rddRow.mapPartitions(predictor.predictRDD, preservesPartitioning = true)
    dataset.sparkSession.createDataFrame(rddWithPredicted, newSchema)
  }

  override def write: MLWriter = new AngelSaverLoader.AngelModelWriter(this)

  override def copy(extra: ParamMap): CusAngelClassifierModel = defaultCopy(extra)

  // Single-instance prediction is not supported; scoring goes through transform().
  override def predict(features: linalg.Vector): Double = ???

  // Raw-score prediction is not supported; scoring goes through transform().
  override protected def predictRaw(features: linalg.Vector): linalg.Vector = ???
}

/** Companion object wiring up deserialization for [[CusAngelClassifierModel]]. */
object CusAngelClassifierModel extends MLReadable[CusAngelClassifierModel] with Logging {
  // Lazily obtain the driver's Angel PS client; synchronized to guard first initialization.
  private lazy implicit val psClient: AngelPSClient = synchronized {
    DriverContext.get().getAngelClient
  }

  /** Returns a fresh model reader on every call. */
  override def read: MLReader[CusAngelClassifierModel] =
    new AngelSaverLoader.AngelModelReader[CusAngelClassifierModel]()
}
